import os
import cv2
import shutil
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from scipy.stats import skew
from skimage import measure
import matplotlib.pyplot as plt
import plotly.graph_objects as go

# ---- Output layout ---------------------------------------------------------
base_dir = r"Path to data"  # TODO: point at the dataset root before running
data_path = os.path.join(base_dir, "train") 
output_dir = os.path.join(data_path, "Filtering_Results_Final")
# NOTE(review): output_dir is nested INSIDE data_path, so anything written
# here (plots, cleaned copies) is visible to a later os.walk over data_path.
plot_dir = os.path.join(output_dir, "Plots")
cleaned_dir = os.path.join(output_dir, "Cleaned_Dataset")

# Create all output folders up front; exist_ok makes re-runs idempotent.
for d in [output_dir, plot_dir, cleaned_dir]:
    os.makedirs(d, exist_ok=True)


def brenner_focus(img_gray):
    """Brenner focus measure: mean squared intensity difference across a
    two-row vertical offset. Higher means sharper.

    The image is cast to float32 first: the original uint8 arithmetic
    wrapped around on negative differences (e.g. 10 - 200 -> 66) and
    overflowed when squaring, silently corrupting the score.

    Returns 0.0 for images with fewer than 3 rows (no valid differences).
    """
    img = np.asarray(img_gray, dtype=np.float32)
    diff = img[2:, :] - img[:-2, :]
    if diff.size == 0:
        # np.mean on an empty array would return NaN with a warning.
        return 0.0
    return float(np.mean(diff ** 2))

def edge_coherence(img_gray):
    """Strong-edge ratio: fraction of pixels whose Sobel gradient
    magnitude exceeds the 75th percentile of the magnitude map."""
    grad_x = cv2.Sobel(img_gray, cv2.CV_64F, 1, 0)
    grad_y = cv2.Sobel(img_gray, cv2.CV_64F, 0, 1)
    magnitude = np.sqrt(grad_x**2 + grad_y**2)
    cutoff = np.percentile(magnitude, 75)
    strong = np.count_nonzero(magnitude > cutoff)
    # Epsilon keeps the division safe for degenerate (empty) inputs.
    return strong / (magnitude.size + 1e-9)

def edge_density(img_gray):
    """Proportion of pixels marked as edges by a Canny detector
    with hysteresis thresholds 80/160."""
    edge_map = cv2.Canny(img_gray, 80, 160)
    n_edge_pixels = np.count_nonzero(edge_map)
    return n_edge_pixels / (edge_map.size + 1e-9)


def extract_features(img_path):
    """Compute per-image quality features for a single image file.

    Loads the image with OpenCV, converts it to a 512x512 grayscale view,
    and returns a dict of sharpness, structure, naturalness (NSS) and
    exposure statistics plus the source path. Returns None when the file
    cannot be decoded, so callers can skip it.

    Note: the dict insertion order determines the DataFrame / CSV column
    order downstream.
    """
    img = cv2.imread(img_path)
    if img is None:
        # Unreadable or non-image file.
        return None

    img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # Fixed 512x512 working size makes feature scales comparable across images.
    img_gray = cv2.resize(img_gray, (512, 512))
    img_fl = img_gray.astype(np.float32)

    # NSS: mean-subtracted, contrast-normalized (MSCN) coefficients.
    # Local mean/std are estimated with a Gaussian blur; the +1.0 in the
    # denominator stabilizes flat regions where sigma ~ 0.
    mu = cv2.GaussianBlur(img_fl, (7, 7), 1.166)
    sigma = np.sqrt(np.abs(cv2.GaussianBlur(img_fl**2, (7, 7), 1.166) - mu**2))
    mscn = (img_fl - mu) / (sigma + 1.0)

    return {
        # Sharpness: Laplacian variance, Tenengrad (Sobel energy), Brenner.
        "Sharp_Lap": cv2.Laplacian(img_gray, cv2.CV_64F).var(),
        "Sharp_Tenengrad": np.mean(
            cv2.Sobel(img_gray, cv2.CV_64F, 1, 0)**2 +
            cv2.Sobel(img_gray, cv2.CV_64F, 0, 1)**2
        ),
        "Sharp_Brenner": brenner_focus(img_gray),
        # Structural information: entropy, strong-edge ratio, Canny density.
        "Struct_Entropy": measure.shannon_entropy(img_gray),
        "Struct_Coherence": edge_coherence(img_gray),
        "Struct_EdgeDensity": edge_density(img_gray),
        # Naturalness: variance of MSCN coefficients.
        "NSS_Var": np.var(mscn),
        # Exposure: deviation of the mean from mid-gray, and the fraction
        # of near-black (<=1) / near-white (>=254) clipped pixels.
        "Exposure_Dev": abs(np.mean(img_gray) - 127.5),
        "Exposure_Clip": (np.sum(img_gray <= 1) + np.sum(img_gray >= 254)) / img_gray.size,
        "path": img_path
    }


# Gather candidate image paths under data_path. The results folder lives
# INSIDE data_path, so prune it during the walk — otherwise a second run
# would ingest its own saved plots and previously cleaned copies.
image_files = []
for root, dirs, files in os.walk(data_path):
    dirs[:] = [d for d in dirs if os.path.join(root, d) != output_dir]
    image_files.extend(
        os.path.join(root, f)
        for f in files
        if f.lower().endswith((".png", ".jpg", ".jpeg"))
    )

records = []
for fpath in tqdm(image_files, desc="Extracting features"):
    feats = extract_features(fpath)
    if feats:
        records.append(feats)

# Fail fast with a clear message instead of an opaque KeyError downstream
# when the path is wrong or every file failed to decode.
if not records:
    raise SystemExit(f"No readable images found under {data_path}")

df = pd.DataFrame(records).fillna(0)


# --- Min-max normalize every numeric feature column into [0, 1] -------------
numeric = df.select_dtypes(include=[np.number]).columns
col_min = df[numeric].min()
col_range = df[numeric].max() - col_min + 1e-9  # epsilon guards constant columns
df[numeric] = (df[numeric] - col_min) / col_range

# --- Collapse features into three interpretable axes ------------------------
# X: sharpness, Y: structural information, Z: perceptual reliability
# (inverted so that higher is better).
df["X_Axis"] = df[["Sharp_Lap", "Sharp_Tenengrad", "Sharp_Brenner"]].mean(axis=1)
df["Y_Axis"] = df[["Struct_Entropy", "Struct_Coherence", "Struct_EdgeDensity"]].mean(axis=1)
df["Z_Axis"] = 1.0 - df[["NSS_Var", "Exposure_Dev", "Exposure_Clip"]].mean(axis=1)

# --- Variance-based axis weights: more-dispersed axes get more influence ----
stacked = np.vstack([df["X_Axis"], df["Y_Axis"], df["Z_Axis"]])
spread = np.std(stacked, axis=1)
w = spread / (spread.sum() + 1e-9)

df["Composite_Score"] = w[0]*df["X_Axis"] + w[1]*df["Y_Axis"] + w[2]*df["Z_Axis"]

# --- Skew-adaptive threshold: a heavier tail widens the rejection margin ----
score = df["Composite_Score"]
mu, sd, skv = score.mean(), score.std(), skew(score)
theta = mu - (1.5 + min(abs(skv) * 2, 2.0)) * sd

df["Decision"] = np.where(score >= theta, "Keep", "Flag")


# Persist the full per-image quality report.
csv_path = os.path.join(output_dir, "Quality_Report.csv")
df.to_csv(csv_path, index=False)

# Histogram of composite scores with the decision threshold marked in red.
fig_hist = plt.figure(figsize=(10, 6))
ax_hist = fig_hist.add_subplot(111)
ax_hist.hist(df["Composite_Score"], bins=50, color="teal", alpha=0.75)
ax_hist.axvline(theta, color="red", linestyle="--", linewidth=2)
ax_hist.set_xlabel("Composite Score")
ax_hist.set_ylabel("Image Count")
ax_hist.set_title("Distribution of Image Quality Scores")
fig_hist.tight_layout()

hist_path = os.path.join(plot_dir, "threshold_histogram.png")
fig_hist.savefig(hist_path, dpi=300)
plt.close(fig_hist)


# Montage of the lowest-scoring flagged images for visual inspection.
flagged_df = df[df["Decision"]=="Flag"].sort_values("Composite_Score")

if len(flagged_df) > 0:
    n = min(16, len(flagged_df))
    rows = (n + 3)//4
    fig, axes = plt.subplots(rows, 4, figsize=(15, rows*4))
    axes = np.atleast_1d(axes).flatten()

    shown = 0
    for i in range(n):
        # Guard against files that disappeared or became unreadable between
        # feature extraction and plotting: cv2.imread returns None on
        # failure, and the original code crashed inside cvtColor.
        bgr = cv2.imread(flagged_df.iloc[i]["path"])
        if bgr is None:
            continue
        axes[shown].imshow(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
        axes[shown].set_title(f"{flagged_df.iloc[i]['Composite_Score']:.3f}")
        axes[shown].axis("off")
        shown += 1

    # Hide every unused panel in the grid.
    for ax in axes[shown:]:
        ax.axis("off")

    plt.tight_layout()
    montage_path = os.path.join(plot_dir, "flagged_images.png")
    plt.savefig(montage_path, dpi=300)
    plt.close()
else:
    print("No flagged images")


# 3-D scatter of all images in (X, Y, Z) quality space, with the decision
# hyperplane w_x*X + w_y*Y + w_z*Z = theta drawn as a translucent surface.
fig = plt.figure(figsize=(10,8))
ax = fig.add_subplot(111, projection="3d")

keep = df[df["Decision"]=="Keep"]
flag = df[df["Decision"]=="Flag"]

ax.scatter(keep["X_Axis"], keep["Y_Axis"], keep["Z_Axis"],
           c="teal", alpha=0.6, s=15, label="Keep")
ax.scatter(flag["X_Axis"], flag["Y_Axis"], flag["Z_Axis"],
           c="red", marker="x", s=50, label="Flagged")

# Solve the plane equation for Z over an (X, Y) grid; the epsilon guards
# against a zero Z-weight.
wx, wy, wz = w
x = np.linspace(0,1,120)
y = np.linspace(0,1,120)
X, Y = np.meshgrid(x, y)
Z = (theta - wx*X - wy*Y) / (wz + 1e-9)
# Clip the surface to the unit cube by masking out-of-range Z values.
Z[(Z < 0) | (Z > 1)] = np.nan

ax.plot_surface(
    X, Y, Z,
    color="gray",
    alpha=0.25,
    edgecolor="none",
    antialiased=True,
    shade=True
)

ax.set_xlim(0,1); ax.set_ylim(0,1); ax.set_zlim(0,1)
ax.set_xlabel("Sharpness (X)")
ax.set_ylabel("Structural Information (Y)")
ax.set_zlabel("Perceptual Reliability (Z)")
ax.set_title(r"Decision Hyperplane: $w_xX + w_yY + w_zZ = \theta$")
ax.legend()
ax.view_init(elev=25, azim=-60)

static_path = os.path.join(plot_dir, "decision_manifold_static.png")
plt.tight_layout()
plt.savefig(static_path, dpi=300)
plt.close()



# Copy every kept image into cleaned_dir, preserving the relative layout
# of the source tree (metadata included via copy2).
for src in tqdm(keep["path"], desc="Saved cleaned dataset"):
    rel_path = os.path.relpath(src, data_path)
    target = os.path.join(cleaned_dir, rel_path)
    os.makedirs(os.path.dirname(target), exist_ok=True)
    shutil.copy2(src, target)



# Final run summary: dataset sizes, flag rate, threshold, and axis weights.
print(f"Total Images   : {len(df)}")
print(f"Images Kept    : {len(keep)}")
print(f"Images Flagged : {len(flag)} ({len(flag)/len(df)*100:.2f}%)")
print(f"Threshold θ    : {theta:.4f}")
print(f"Weights        : X={w[0]:.3f}, Y={w[1]:.3f}, Z={w[2]:.3f}")

